package com.yeskery.nut.transaction;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Optional;
import java.util.Stack;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Default transaction manager.
 * <p>
 * Each thread keeps a stack of {@link TransactionHolder}s. Starting a scope either pushes a new
 * holder (REQUIRES_NEW, NOT_SUPPORTED, REQUIRED without an active transaction) or joins the holder
 * on top of the stack by incrementing its nesting count (REQUIRED/MANDATORY inside a transaction,
 * SUPPORTS/NEVER inside any active scope). The counting logic assumes that every
 * {@code startTransaction} call that pushed or joined a holder is paired with exactly one
 * {@code commitTransaction} or {@code rollbackTransaction} call: that call decrements the count and
 * only finishes (commits/rolls back, closes and pops) the holder once the count has reached one.
 *
 * @author sprout
 * @version 1.0
 * 2022-08-27 14:03
 */
public class DefaultTransactionManager implements TransactionManager {

    /** Logger. */
    private static final Logger logger = Logger.getLogger(DefaultTransactionManager.class.getName());

    /**
     * Per-thread stack of transaction holders; the top element is the current scope.
     * NOTE(review): InheritableThreadLocal hands a child thread the parent's Stack instance itself
     * (not a copy), so a child thread that commits or rolls back pops the parent's stack.
     * Confirm this sharing is intended.
     */
    private final ThreadLocal<Stack<TransactionHolder>> transactionHolderThreadLocal = new InheritableThreadLocal<>();

    /** Data source used to open a connection per new transaction (takes precedence over {@link #connection}). */
    private DataSource dataSource;

    /** Fixed connection used when no data source is configured. */
    private Connection connection;

    /**
     * Creates a transaction manager backed by a data source.
     * @param dataSource the data source
     */
    public DefaultTransactionManager(DataSource dataSource) {
        setDataSource(dataSource);
    }

    /**
     * Creates a transaction manager backed by a single connection.
     * @param connection the connection
     */
    public DefaultTransactionManager(Connection connection) {
        setConnection(connection);
    }

    @Override
    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    @Override
    public void setConnection(Connection connection) {
        this.connection = connection;
    }

    /**
     * Starts (or joins) a transactional scope according to the propagation behaviour.
     * @param level isolation level used when a new transaction is created
     * @param propagation propagation behaviour
     * @throws TransactionException if neither a data source nor a connection is configured, or the
     *         propagation constraints (MANDATORY / NEVER) are violated
     */
    @Override
    public void startTransaction(TransactionIsolationLevel level, Propagation propagation) {
        // Validate configuration before touching the thread-local state.
        if (connection == null && dataSource == null) {
            throw new TransactionException("TransactionManager Not Set Data Source.");
        }
        Stack<TransactionHolder> stack = transactionHolderThreadLocal.get();
        if (stack == null) {
            stack = new Stack<>();
            transactionHolderThreadLocal.set(stack);
        }
        try {
            if (propagation == Propagation.REQUIRED) {
                if (isExistTransaction(stack)) {
                    addNestTransaction(stack);
                } else {
                    stack.push(createNewTransactionHolder(level, false, propagation));
                }
            } else if (propagation == Propagation.REQUIRES_NEW) {
                stack.push(createNewTransactionHolder(level, false, propagation));
            } else if (propagation == Propagation.SUPPORTS) {
                // Join whatever scope is on top (transactional or NOT_SUPPORTED) so that the matching
                // commit/rollback only decrements its count instead of finishing that scope early.
                if (!stack.isEmpty()) {
                    addNestTransaction(stack);
                }
            } else if (propagation == Propagation.MANDATORY) {
                if (isExistTransaction(stack)) {
                    addNestTransaction(stack);
                } else {
                    throw new TransactionException("Transaction Propagation [MANDATORY] Must Run On Transaction Environment.");
                }
            } else if (propagation == Propagation.NOT_SUPPORTED) {
                stack.push(createNewTransactionHolder(level, true, propagation));
            } else if (propagation == Propagation.NEVER) {
                if (isExistTransaction(stack)) {
                    throw new TransactionException("Transaction Propagation [NEVER] Must Run On No Transaction Environment.");
                }
                // Same pairing reasoning as SUPPORTS: join an active non-transactional scope.
                if (!stack.isEmpty()) {
                    addNestTransaction(stack);
                }
            }
        } finally {
            // Drop the thread-local again when nothing ended up on the stack
            // (e.g. SUPPORTS/NEVER without an active scope, or a failure above).
            clearTransactionHolder();
        }
    }

    /**
     * Commits the current scope. Nested scopes only decrement the holder's count; the outermost
     * one runs the before-commit callbacks, commits, closes and pops the holder, then runs the
     * after-commit and after-completed callbacks. A non-transactional (NOT_SUPPORTED) holder is
     * closed and popped without committing.
     * @throws TransactionException if no scope is active
     */
    @Override
    public void commitTransaction() {
        TransactionHolder transactionHolder = getCurrentTransactionHolder(true);
        try {
            if (leaveNestedScope(transactionHolder)) {
                return;
            }
            Transaction transaction = transactionHolder.getTransaction();
            if (isExistTransaction(transactionHolder)) {
                invokeCallbacks(transactionHolder.getBeforeCommitCallbacks(),
                        "Transaction Commit BeforeCommitCallback [");
                // If commit/close fails the holder deliberately stays on the stack, presumably so the
                // caller's error path can still roll it back via rollbackTransaction().
                transaction.commit();
                transaction.close();
                transactionHolderThreadLocal.get().pop();
                invokeCallbacks(transactionHolder.getAfterCommitCallbacks(),
                        "Transaction Commit AfterCommitCallback [");
            } else {
                // NOT_SUPPORTED scope: auto-commit connection, nothing to commit; just release it.
                transaction.close();
                transactionHolderThreadLocal.get().pop();
            }
            // Only fired once the holder is really finished (not on nested commits).
            invokeCallbacks(transactionHolder.getAfterCompletedCallbacks(),
                    "Transaction Commit AfterCompletedCallback [");
        } finally {
            clearTransactionHolder();
        }
    }

    /**
     * Rolls back the current scope. Nested scopes only decrement the holder's count; the outermost
     * one rolls back (skipped for NOT_SUPPORTED holders), then always closes and pops the holder,
     * and finally runs the after-rollback and after-completed callbacks.
     * NOTE(review): a nested rollback only decrements the count; nothing marks the shared
     * transaction rollback-only, so a later outer commit still commits the inner work.
     * Confirm this is intended.
     * @throws TransactionException if no scope is active
     */
    @Override
    public void rollbackTransaction() {
        TransactionHolder transactionHolder = getCurrentTransactionHolder(true);
        try {
            if (leaveNestedScope(transactionHolder)) {
                return;
            }
            Transaction transaction = transactionHolder.getTransaction();
            boolean transactional = isExistTransaction(transactionHolder);
            try {
                if (transactional) {
                    transaction.rollback();
                }
            } finally {
                // Rollback is the terminal operation for this scope, so always release the
                // transaction and its stack slot, even when the rollback itself failed.
                try {
                    transaction.close();
                } finally {
                    transactionHolderThreadLocal.get().pop();
                }
            }
            if (transactional) {
                invokeCallbacks(transactionHolder.getAfterRollbackCallbacks(),
                        "Transaction Rollback AfterRollbackCallback [");
            }
            invokeCallbacks(transactionHolder.getAfterCompletedCallbacks(),
                    "Transaction Rollback CompletedCallback [");
        } finally {
            clearTransactionHolder();
        }
    }

    @Override
    public Transaction getCurrentTransaction() {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(true);
        return currentTransactionHolder == null ? null : currentTransactionHolder.getTransaction();
    }

    @Override
    public Connection getCurrentConnection() {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(true);
        return currentTransactionHolder == null ? null : currentTransactionHolder.getConnection();
    }

    @Override
    public Optional<Transaction> getCurrentTransactionOptional() {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(false);
        return Optional.ofNullable(currentTransactionHolder).map(TransactionHolder::getTransaction);
    }

    @Override
    public Optional<Connection> getCurrentConnectionOptional() {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(false);
        return Optional.ofNullable(currentTransactionHolder).map(TransactionHolder::getConnection);
    }

    @Override
    public void registerAfterCompletedTransactionCallback(TransactionCallback afterCompletedTransactionCallback) {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(true);
        if (currentTransactionHolder != null) {
            currentTransactionHolder.getAfterCompletedCallbacks().add(afterCompletedTransactionCallback);
        }
    }

    @Override
    public void registerBeforeCommitTransactionCallback(TransactionCallback beforeCommitTransactionCallback) {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(true);
        if (currentTransactionHolder != null) {
            currentTransactionHolder.getBeforeCommitCallbacks().add(beforeCommitTransactionCallback);
        }
    }

    @Override
    public void registerAfterCommitTransactionCallback(TransactionCallback afterCommitTransactionCallback) {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(true);
        if (currentTransactionHolder != null) {
            currentTransactionHolder.getAfterCommitCallbacks().add(afterCommitTransactionCallback);
        }
    }

    @Override
    public void registerAfterRollbackTransactionCallback(TransactionCallback afterRollbackTransactionCallback) {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(true);
        if (currentTransactionHolder != null) {
            currentTransactionHolder.getAfterRollbackCallbacks().add(afterRollbackTransactionCallback);
        }
    }

    @Override
    public void putTransactionResource(Object name, Object resource) {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(true);
        if (currentTransactionHolder != null) {
            currentTransactionHolder.getResourceMap().put(name, resource);
        }
    }

    @Override
    public Object getTransactionResource(Object name) {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(false);
        if (currentTransactionHolder == null) {
            return null;
        }
        return currentTransactionHolder.getResourceMap().get(name);
    }

    @Override
    public void removeTransactionResource(Object name) {
        TransactionHolder currentTransactionHolder = getCurrentTransactionHolder(false);
        if (currentTransactionHolder == null) {
            return;
        }
        currentTransactionHolder.getResourceMap().remove(name);
    }

    /**
     * Creates a new transaction holder.
     * @param level transaction isolation level
     * @param autoCommit whether the connection should auto-commit
     * @param propagation propagation behaviour recorded on the holder
     * @return the new transaction holder
     * @throws TransactionException if the auto-commit mode cannot be set; the freshly created
     *         transaction is closed before the exception is thrown
     */
    protected TransactionHolder createNewTransactionHolder(TransactionIsolationLevel level, boolean autoCommit, Propagation propagation) {
        JdbcTransaction jdbcTransaction = dataSource == null
                ? new JdbcTransaction(connection)
                : new JdbcTransaction(dataSource, level, autoCommit);
        Connection transactionConnection = jdbcTransaction.getConnection();
        try {
            transactionConnection.setAutoCommit(autoCommit);
        } catch (SQLException e) {
            TransactionException exception = new TransactionException("Connection AutoCommit Close Fail.", e);
            // Release what was just acquired instead of leaking it; keep the original failure primary.
            try {
                jdbcTransaction.close();
            } catch (RuntimeException closeException) {
                exception.addSuppressed(closeException);
            }
            throw exception;
        }
        return new TransactionHolder(jdbcTransaction, transactionConnection, propagation);
    }

    /**
     * Returns the holder on top of the current thread's stack. Removes the thread-local entry
     * when the stack is missing or empty.
     * @param throwable whether to throw when no scope is active
     * @return the current holder, or {@code null} if none and {@code throwable} is false
     * @throws TransactionException if no scope is active and {@code throwable} is true
     */
    protected TransactionHolder getCurrentTransactionHolder(boolean throwable) {
        Stack<TransactionHolder> stack = transactionHolderThreadLocal.get();
        if (stack == null || stack.isEmpty()) {
            transactionHolderThreadLocal.remove();
            if (throwable) {
                throw new TransactionException("The Current Transaction Not Started.");
            }
            return null;
        }
        return stack.peek();
    }

    /**
     * Removes the thread-local entry once the current thread's stack is missing or empty.
     */
    protected void clearTransactionHolder() {
        Stack<TransactionHolder> stack = transactionHolderThreadLocal.get();
        if (stack == null || stack.isEmpty()) {
            transactionHolderThreadLocal.remove();
        }
    }

    /**
     * Leaves one nesting level of the given holder.
     * @param transactionHolder the current holder
     * @return {@code true} if an outer scope still uses the holder (its count was decremented),
     *         {@code false} if the caller must finish the holder now
     */
    private boolean leaveNestedScope(TransactionHolder transactionHolder) {
        int count = transactionHolder.getCount();
        if (count > 1) {
            transactionHolder.setCount(count - 1);
            return true;
        }
        return false;
    }

    /**
     * Runs each callback, logging (not propagating) any exception it throws.
     * @param callbacks callbacks to run, in iteration order
     * @param messagePrefix log message prefix; the callback class name and "] Execute Fail." are appended
     */
    private void invokeCallbacks(Iterable<? extends TransactionCallback> callbacks, String messagePrefix) {
        for (TransactionCallback callback : callbacks) {
            try {
                callback.callback();
            } catch (Exception e) {
                String callbackClassName = callback.getClass().getName();
                logger.logp(Level.SEVERE, callbackClassName, "callback",
                        messagePrefix + callbackClassName + "] Execute Fail.", e);
            }
        }
    }

    /**
     * Joins the holder on top of the stack by incrementing its nesting count.
     * @param stack the current thread's holder stack (must not be empty)
     */
    private void addNestTransaction(Stack<TransactionHolder> stack) {
        TransactionHolder transactionHolder = stack.peek();
        int count = transactionHolder.getCount();
        transactionHolder.setCount(count + 1);
    }

    /**
     * Whether the top of the stack is a real transaction.
     * @param stack the current thread's holder stack
     * @return {@code false} for an empty stack or a non-transactional top holder
     */
    private boolean isExistTransaction(Stack<TransactionHolder> stack) {
        if (stack.isEmpty()) {
            return false;
        }
        return isExistTransaction(stack.peek());
    }

    /**
     * Whether the holder represents a real transaction, i.e. is non-null and its propagation is
     * neither NOT_SUPPORTED nor NEVER.
     * @param transactionHolder the holder, may be {@code null}
     * @return whether a transaction exists
     */
    private boolean isExistTransaction(TransactionHolder transactionHolder) {
        if (transactionHolder == null) {
            return false;
        }
        return transactionHolder.getPropagation() != Propagation.NOT_SUPPORTED
                && transactionHolder.getPropagation() != Propagation.NEVER;
    }
}
